package vfile.djl;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications.Classification;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.*;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import com.arcsoft.face.FaceInfo;
import com.arcsoft.face.Rect;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

public class DjlImageVisitor{

    Predictor<Image, DetectedObjects> predictor;
    int counter;
    private ImageFactory factory;

    public DjlImageVisitor() throws IOException, ModelException {
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optArtifactId("ai.djl.mxnet:ssd")
                        .optFilter("backbone","vgg16")
                        .build();
        predictor = ModelZoo.loadModel(criteria).newPredictor();
        counter = 0;
        factory = ImageFactory.getInstance();
    }

    public List<FaceInfo> process(BufferedImage bImg) {

        try {
            int imageWidth = bImg.getWidth();
            int imageHeight = bImg.getHeight();
            Image image = factory.fromImage(bImg);
            DetectedObjects prediction = predictor.predict(image);
            String classStr =
                    prediction
                            .items()
                            .stream()
                            .map(Classification::getClassName)
                            .collect(Collectors.joining(", "));
            System.out.println("Found objects: " + classStr);
            boolean hasPerson =
                    prediction
                            .items()
                            .stream()
                            .anyMatch(
                                    c ->
                                            "person".equals(c.getClassName())
                                                    && c.getProbability() > 0.5);
            List<FaceInfo> faceInfoList = new ArrayList<FaceInfo>();
            List<DetectedObjects.DetectedObject> list = prediction.items();
            for (DetectedObjects.DetectedObject result : list) {
                String className = result.getClassName();
                BoundingBox box = result.getBoundingBox();

                Rectangle rectangle = box.getBounds();
                int x = (int) (rectangle.getX() * imageWidth);
                int y = (int) (rectangle.getY() * imageHeight);
                FaceInfo info = new FaceInfo();
                Rect rect = new Rect();
                rect.setLeft(x);
                rect.setTop(y);
                rect.setRight(x+ (int) (rectangle.getWidth() * imageWidth));
                rect.setBottom(y+(int) (rectangle.getHeight() * imageHeight));
                info.setRect(rect);
                faceInfoList.add(info);
            }
            //image.drawBoundingBoxes(prediction);
            return faceInfoList;
          /*  Path outputFile = Paths.get("out/image-" + counter + ".png");
            try (OutputStream os = Files.newOutputStream(outputFile)) {
                image.save(os, "png");
            }
            counter++;*/
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}
